/*
 * BeliefPropagation.cpp
 *
 *  Created on: Dec 12, 2013
 *      Author: nino
 */

#include <BeliefPropagation.h>


/*
 * Dynamic message passing (DMP) for SIR spreading from a single source node.
 *
 * p                     : per-step transmission probability (lambda).
 * q                     : per-step recovery probability (mu).
 * p_infected_parameters : adjacency list (al), node count (no_of_nodes) and
 *                         the initially infected node (start_node).
 * T_value               : number of DMP time steps to run; values <= 0 run
 *                         no step and report the initial condition.
 * prob_inf_DMP          : output, allocated and sized by the caller. Entry i
 *                         receives 1 - P_S^i(T_value), i.e. the probability
 *                         that node i is no longer susceptible.
 *
 * Each quantity lives in two buffers (suffix _0/_1). Instead of copying,
 * the last/now/next pointers are re-pointed after every step.
 */
void DMP_algorithm(double p, double q, infected_structure * p_infected_parameters, int T_value, igraph_vector_t * prob_inf_DMP){

	// INIT
	igraph_spmatrix_t * P_S_now, * P_S_next, * P_S_last;
	igraph_spmatrix_t P_S_0, P_S_1;

	igraph_spmatrix_init( &P_S_0, p_infected_parameters->no_of_nodes, p_infected_parameters->no_of_nodes );
	igraph_spmatrix_init( &P_S_1, p_infected_parameters->no_of_nodes, p_infected_parameters->no_of_nodes );

	igraph_spmatrix_t * Phi_now, * Phi_last;
	igraph_spmatrix_t Phi_0, Phi_1;

	igraph_spmatrix_init( &Phi_0, p_infected_parameters->no_of_nodes, p_infected_parameters->no_of_nodes );
	igraph_spmatrix_init( &Phi_1, p_infected_parameters->no_of_nodes, p_infected_parameters->no_of_nodes );

	igraph_spmatrix_t * Theta_next, * Theta_now;
	igraph_spmatrix_t Theta_0, Theta_1;

	igraph_spmatrix_init( &Theta_0, p_infected_parameters->no_of_nodes, p_infected_parameters->no_of_nodes );
	igraph_spmatrix_init( &Theta_1, p_infected_parameters->no_of_nodes, p_infected_parameters->no_of_nodes );

	igraph_vector_t prob_S_1, prob_R_1, prob_I_1, prob_R_0, prob_I_0;
	igraph_vector_t *prob_S_next, *prob_R_next, *prob_I_next, *prob_R_now, *prob_I_now;

	igraph_vector_init ( &prob_S_1, p_infected_parameters->no_of_nodes);
	igraph_vector_init ( &prob_R_1, p_infected_parameters->no_of_nodes);
	igraph_vector_init ( &prob_I_1, p_infected_parameters->no_of_nodes);
	igraph_vector_init ( &prob_R_0, p_infected_parameters->no_of_nodes);
	igraph_vector_init ( &prob_I_0, p_infected_parameters->no_of_nodes);

	// Seed prob_S_1 with P_S^i(0): 0 for the source, 1 for everybody else.
	// This keeps prob_S_next valid when T_value <= 0 and the loop never runs
	// (it used to be read uninitialised). Every step overwrites it anyway.
	for (long int node_tmp = 0; node_tmp < igraph_vector_size(&prob_S_1); ++node_tmp){
		VECTOR(prob_S_1)[node_tmp] = (node_tmp == p_infected_parameters->start_node) ? (igraph_real_t) 0.0 : (igraph_real_t) 1.0;
	}
	prob_S_next = & prob_S_1;


	// DMP start
	for( int time = 0; time < T_value; time++){
		if (time == 0){

			Theta_now = & Theta_0;
			Phi_now = & Phi_0;
			prob_I_now = & prob_I_0;
			initial_conditions_DMP(Theta_now, Phi_now, prob_I_now, p_infected_parameters);

			P_S_now = & P_S_0;
			update_P_S_DMP(P_S_now, Theta_now, p_infected_parameters);
			//Initial conditions set !

			Theta_next = & Theta_1;
			update_theta_DMP(p, Theta_next, Theta_now, Phi_now, p_infected_parameters);

			P_S_next = & P_S_1;

			update_P_S_DMP(P_S_next, Theta_next, p_infected_parameters);

			prob_S_next = & prob_S_1;
			update_marginal_prob_S_DMP( prob_S_next, Theta_next, p_infected_parameters);
			prob_R_next = & prob_R_1;
			prob_R_now = & prob_R_0;
			update_marginal_prob_R_DMP(q, prob_R_next, prob_R_now, prob_I_now);
			prob_I_next = & prob_I_1;
			update_marginal_prob_I_DMP(prob_S_next, prob_I_next, prob_R_next);
			// Theta_next and P_S_next for the first step calculated !


			// Time swap of pointers for next time step
			Phi_now = & Phi_1;
			Phi_last = & Phi_0;
			P_S_now = & P_S_1;
			P_S_last = & P_S_0;
			Theta_next = & Theta_0; //Overwriting old
			Theta_now = & Theta_1;
			P_S_next = & P_S_0; //Overwriting old

			// Step 0 wrote the slot-1 marginals: they become "now".
			prob_I_next = & prob_I_0;
			prob_I_now = & prob_I_1;
			prob_R_next = & prob_R_0;
			prob_R_now = & prob_R_1;

		}else{

			update_phi_DMP(p, q, Phi_now, Phi_last, P_S_now, P_S_last, p_infected_parameters);

			update_theta_DMP(p, Theta_next, Theta_now, Phi_now, p_infected_parameters);

			update_P_S_DMP(P_S_next, Theta_next, p_infected_parameters);

			update_marginal_prob_S_DMP( prob_S_next, Theta_next, p_infected_parameters);

			update_marginal_prob_R_DMP(q, prob_R_next, prob_R_now, prob_I_now);

			update_marginal_prob_I_DMP(prob_S_next, prob_I_next, prob_R_next);
			//Theta_next, P_S_next and Phi_now for this step calculated

			// Time swap of pointers for next time step
			// !!! Caution we are at time step but are making changes for next moment
			if ( time % 2 != 0){
				Phi_now = & Phi_0;
				Phi_last = & Phi_1;
				P_S_now = & P_S_0;
				P_S_last = & P_S_1;
				Theta_next = & Theta_1;
				Theta_now = & Theta_0;
				P_S_next = & P_S_1;

				// Odd steps write the slot-0 marginals (step 0 wrote slot 1),
				// so slot 0 becomes "now" and slot 1 is recycled.
				prob_I_next = & prob_I_1;
				prob_I_now = & prob_I_0;
				prob_R_next = & prob_R_1;
				prob_R_now = & prob_R_0;
			}else{
				Phi_now = & Phi_1;
				Phi_last = & Phi_0;
				P_S_now = & P_S_1;
				P_S_last = & P_S_0;
				Theta_next = & Theta_0;
				Theta_now = & Theta_1;
				P_S_next = & P_S_0;

				// Even steps (>= 2) write the slot-1 marginals.
				prob_I_next = & prob_I_0;
				prob_I_now = & prob_I_1;
				prob_R_next = & prob_R_0;
				prob_R_now = & prob_R_1;
			}

		}
	}
	// DMP algorithm END

	// Output the probability that each node has been infected (or recovered).
	// Bounded by both vector sizes so a mis-sized output cannot read past prob_S_next.
	long int no_out = igraph_vector_size(prob_inf_DMP);
	if (no_out > igraph_vector_size(prob_S_next)){
		no_out = igraph_vector_size(prob_S_next);
	}
	for(long int node_tmp=0; node_tmp < no_out; ++node_tmp){
		VECTOR(*prob_inf_DMP)[node_tmp] = (igraph_real_t) 1.0 - VECTOR(*prob_S_next)[node_tmp];
	}

	igraph_spmatrix_destroy( &P_S_0 );
	igraph_spmatrix_destroy( &P_S_1 );
	igraph_spmatrix_destroy( &Phi_0 );
	igraph_spmatrix_destroy( &Phi_1 );
	igraph_spmatrix_destroy( &Theta_0 );
	igraph_spmatrix_destroy( &Theta_1 );

	igraph_vector_destroy ( &prob_S_1 );
	igraph_vector_destroy ( &prob_I_1 );
	igraph_vector_destroy ( &prob_R_1 );
	igraph_vector_destroy ( &prob_R_0 );
	igraph_vector_destroy ( &prob_I_0 );

}

/*
 * Write the DMP initial conditions for a single source node:
 *   - theta_{i->j} = 1 for every directed edge (i, j) of the adjacency list;
 *   - phi_{s->j}   = 1 for every neighbour j of the source s = start_node;
 *   - prob_I_now[s] = 1.
 * Other entries of Phi_now and prob_I_now are left untouched; the caller is
 * expected to pass freshly initialised (zero) containers.
 */
void initial_conditions_DMP(igraph_spmatrix_t * Theta_now, igraph_spmatrix_t * Phi_now, igraph_vector_t *prob_I_now, infected_structure * p_infected_parameters){


	// Initial conditions theta
	// (node_tmp must start at 0: it was previously left uninitialised.)
	for ( long int node_tmp = 0; node_tmp < p_infected_parameters->no_of_nodes; ++node_tmp){
		igraph_vector_t *neis = igraph_adjlist_get(p_infected_parameters->al, node_tmp);

		for (int i = 0; i < (int) igraph_vector_size(neis); ++i) {
			long int current_neigh = (long int) VECTOR(*neis)[i];
			// set (not add) so a neighbour listed twice cannot push theta above 1
			igraph_spmatrix_set( Theta_now, node_tmp, current_neigh , (igraph_real_t) 1.0 );
		}
	}

	// Initial conditions phi
	igraph_vector_t *neis = igraph_adjlist_get(p_infected_parameters->al, p_infected_parameters->start_node);
	for (int i = 0; i < (int) igraph_vector_size(neis); ++i) {
		long int current_neigh = (long int) VECTOR(*neis)[i];
		igraph_spmatrix_set( Phi_now, p_infected_parameters->start_node, current_neigh , (igraph_real_t) 1.0 );
	}

	VECTOR(*prob_I_now)[p_infected_parameters->start_node] = (igraph_real_t) 1.0;
}

/*
 * Cavity susceptibility for every directed edge (i, j):
 *
 *   P_S^{i->j} = P_S^i(0) * prod_{k in neighbours(i), k != j} theta_{k->i}
 *
 * where theta_{k->i} is read from Theta_next at (k, i), and P_S^i(0) is 0 for
 * the source node and 1 otherwise. The result is stored in P_S_next at (i, j).
 *
 * The "all neighbours but j" product is recovered from one pass that keeps
 * the product of the non-zero thetas and the number of exact zeros. This keeps
 * the O(degree) cost of the former log/exp trick, but a theta of exactly 0
 * no longer produces log(0) - log(0) = NaN, and a slightly negative theta
 * (round-off) no longer produces log(<0) = NaN.
 */
void update_P_S_DMP(igraph_spmatrix_t * P_S_next, igraph_spmatrix_t * Theta_next, infected_structure * p_infected_parameters){

	for ( long int node_tmp = 0; node_tmp < p_infected_parameters->no_of_nodes; ++node_tmp){
		igraph_vector_t *neis = igraph_adjlist_get(p_infected_parameters->al, node_tmp);
		const int no_of_neis = (int) igraph_vector_size(neis);

		// P_S^i(0): the source starts infected, every other node susceptible.
		const double initial_susceptible = (node_tmp == p_infected_parameters->start_node) ? 0.0 : 1.0;

		double nonzero_product = 1.0;
		long int zero_count = 0;
		for (int i = 0; i < no_of_neis; ++i) {
			long int current_neigh = (long int) VECTOR(*neis)[i];
			double theta = (double) igraph_spmatrix_e( Theta_next, current_neigh, node_tmp);
			if (theta == 0.0){
				++zero_count;
			}else{
				nonzero_product *= theta;
			}
		}

		for (int i = 0; i < no_of_neis; ++i) {
			long int current_neigh = (long int) VECTOR(*neis)[i];
			double theta = (double) igraph_spmatrix_e( Theta_next, current_neigh, node_tmp);

			// Product over all neighbours except current_neigh.
			double cavity_product;
			if (theta == 0.0){
				// Excluding this zero leaves a non-zero product only if it was the only zero.
				cavity_product = (zero_count == 1) ? nonzero_product : 0.0;
			}else{
				// Any remaining zero annihilates the product.
				cavity_product = (zero_count == 0) ? nonzero_product / theta : 0.0;
			}

			igraph_spmatrix_set( P_S_next, node_tmp, current_neigh , (igraph_real_t) (initial_susceptible * cavity_product) );
		}
	}

}

/*
 * One theta step for every directed edge (i, j) of the adjacency list:
 *
 *   Theta_next(i, j) = Theta_now(i, j) - p * Phi_now(i, j)
 *
 * i.e. the probability that no infection has yet crossed i -> j drops by the
 * transmission probability p times the "infected message" phi_{i->j}.
 */
void update_theta_DMP(double p, igraph_spmatrix_t * Theta_next, igraph_spmatrix_t * Theta_now, igraph_spmatrix_t* Phi_now, infected_structure * p_infected_parameters){

	// Loop-invariant transmission probability.
	const igraph_real_t lambda = (igraph_real_t) p;

	// node_tmp must start at 0 (it was previously left uninitialised).
	for ( long int node_tmp = 0; node_tmp < p_infected_parameters->no_of_nodes; ++node_tmp){
		igraph_vector_t *neis = igraph_adjlist_get(p_infected_parameters->al, node_tmp);

		for (int i = 0; i < (int) igraph_vector_size(neis); ++i) {
			long int current_neigh = (long int) VECTOR(*neis)[i];

			igraph_real_t theta_tmp_now = igraph_spmatrix_e( Theta_now, node_tmp, current_neigh);
			igraph_real_t phi_tmp_now = igraph_spmatrix_e( Phi_now, node_tmp, current_neigh);

			igraph_real_t theta_tmp_update =  theta_tmp_now - (lambda * phi_tmp_now);

			igraph_spmatrix_set( Theta_next, node_tmp, current_neigh , theta_tmp_update );
		}
	}

}

/*
 * One phi step for every directed edge (i, j) of the adjacency list:
 *
 *   Phi_now(i, j) = (1 - p)(1 - q) * Phi_last(i, j)
 *                   - [ P_S_now(i, j) - P_S_last(i, j) ]
 *
 * with p the transmission and q the recovery probability. The first term is
 * the share of the previous message that neither transmitted nor recovered.
 * The bracket removes the change in cavity susceptibility of i.
 */
void update_phi_DMP(double p, double q, igraph_spmatrix_t * Phi_now, igraph_spmatrix_t * Phi_last, igraph_spmatrix_t* P_S_now, igraph_spmatrix_t* P_S_last, infected_structure * p_infected_parameters){

	igraph_real_t lambda = (igraph_real_t) p;
	igraph_real_t mi = (igraph_real_t) q;
	// Loop-invariant decay factor; same evaluation order as (1-l)*(1-m)*phi.
	const igraph_real_t decay = (1-lambda)*(1-mi);


	// node_tmp must start at 0 (it was previously left uninitialised).
	for ( long int node_tmp = 0; node_tmp < p_infected_parameters->no_of_nodes; ++node_tmp){
		igraph_vector_t *neis = igraph_adjlist_get(p_infected_parameters->al, node_tmp);

		for (int i = 0; i < (int) igraph_vector_size(neis); ++i) {
			long int current_neigh = (long int) VECTOR(*neis)[i];

			igraph_real_t phi_tmp_last = igraph_spmatrix_e( Phi_last, node_tmp, current_neigh);

			igraph_real_t p_s_tmp_now = igraph_spmatrix_e( P_S_now, node_tmp, current_neigh);
			igraph_real_t p_s_tmp_last = igraph_spmatrix_e( P_S_last, node_tmp, current_neigh);

			igraph_real_t phi_tmp_update = decay*phi_tmp_last - (p_s_tmp_now - p_s_tmp_last);

			igraph_spmatrix_set( Phi_now, node_tmp, current_neigh , phi_tmp_update );
		}
	}

}

/*
 * Susceptibility marginal of every node:
 *
 *   prob_S_next[i] = P_S^i(0) * prod_{k in neighbours(i)} Theta_next(k, i)
 *
 * P_S^i(0) is 0 for the source (start_node), so the source gets 0 without its
 * neighbours being inspected. Every other node gets the plain product.
 */
void update_marginal_prob_S_DMP(igraph_vector_t *prob_S_next, igraph_spmatrix_t *Theta_next, infected_structure *p_infected_parameters){

	const long int no_of_entries = igraph_vector_size(prob_S_next);

	for (long int idx = 0; idx < no_of_entries; ++idx){
		if (idx == p_infected_parameters->start_node){
			// The source is never susceptible.
			VECTOR(*prob_S_next)[idx] = (igraph_real_t) 0.0;
			continue;
		}

		igraph_vector_t *nbrs = igraph_adjlist_get(p_infected_parameters->al, idx);
		const int degree = (int) igraph_vector_size(nbrs);

		igraph_real_t survival = (igraph_real_t) 1.0;
		for (int k = 0; k < degree; ++k) {
			survival *= igraph_spmatrix_e( Theta_next, (long int) VECTOR(*nbrs)[k], idx);
		}

		VECTOR(*prob_S_next)[idx] = survival;
	}

}

/*
 * Recovered marginal: every currently infected node recovers with probability q,
 *
 *   prob_R_next[i] = prob_R_now[i] + q * prob_I_now[i]
 *
 * for every index of prob_R_next.
 */
void update_marginal_prob_R_DMP(double q, igraph_vector_t *prob_R_next, igraph_vector_t *prob_R_now, igraph_vector_t *prob_I_now){

	const igraph_real_t recovery = (igraph_real_t) q;
	const long int len = igraph_vector_size(prob_R_next);

	long int k = 0;
	while (k < len){
		const igraph_real_t already_recovered = VECTOR(*prob_R_now)[k];
		const igraph_real_t currently_infected = VECTOR(*prob_I_now)[k];
		VECTOR(*prob_R_next)[k] = already_recovered + recovery * currently_infected;
		++k;
	}

}

/*
 * Infected marginal by normalisation of the three SIR states:
 *
 *   prob_I_next[i] = 1 - prob_S_next[i] - prob_R_next[i]
 *
 * The number of entries processed is taken from prob_R_next.
 */
void update_marginal_prob_I_DMP(igraph_vector_t *prob_S_next, igraph_vector_t *prob_I_next, igraph_vector_t *prob_R_next){

	const long int len = igraph_vector_size(prob_R_next);
	for (long int k = 0; k < len; ++k){
		const igraph_real_t not_susceptible = ( (igraph_real_t) 1.0 ) - VECTOR(*prob_S_next)[k];
		VECTOR(*prob_I_next)[k] = not_susceptible - VECTOR(*prob_R_next)[k];
	}
}
